/*
 * SPDX-FileCopyrightText: syuilo and misskey-project
 * SPDX-License-Identifier: AGPL-3.0-only
 */

import vertexSource from './snowfall-effect.vertex.glsl';
import fragmentSource from './snowfall-effect.fragment.glsl';

/**
 * Full-viewport particle overlay ("snow", or cherry blossoms with `sakura: true`)
 * rendered as WebGL2 point sprites.
 *
 * The canvas is fixed over the whole page with `pointer-events: none` and the
 * maximum z-index, so it draws on top of everything without intercepting input.
 *
 * Usage: `new SnowfallEffect({}).render()`; call `destroy()` to remove it again.
 */
export class SnowfallEffect {
	private VERTEX_SOURCE = vertexSource;
	private FRAGMENT_SOURCE = fragmentSource;

	private gl: WebGL2RenderingContext;
	private program: WebGLProgram;
	private canvas: HTMLCanvasElement;
	/** Vertex attribute streams, keyed by attribute name without the `a_` prefix. */
	private buffers: Record<string, {
		/** Components per vertex, passed to `vertexAttribPointer`. */
		size: number;
		value: number[] | Float32Array;
		/** Attribute location; -1 when the linked program does not use the attribute. */
		location: number;
		ref: WebGLBuffer;
	}>;
	/** Shader uniforms, keyed by uniform name without the `u_` prefix. */
	private uniforms: Record<string, {
		/** GLSL type name; selects the `gl.uniform*` setter via `UNIFORM_SETTERS`. */
		type: string;
		value: number | number[] | Float32Array;
		/** `null` when the linked program does not use the uniform; WebGL ignores uploads to a null location. */
		location: WebGLUniformLocation | null;
	}>;
	private texture: WebGLTexture;
	private camera: {
		fov: number;
		near: number;
		far: number;
		aspect: number;
		z: number;
	};
	private wind: {
		current: number;
		force: number;
		target: number;
		min: number;
		max: number;
		easing: number;
	};
	private time: {
		start: number;
		previous: number;
	} = {
		start: 0,
		previous: 0,
	};
	/** Id of the pending `requestAnimationFrame`, or 0 when none is scheduled. */
	private raf = 0;
	/** Set when snowflakes were requested while the canvas had no size; consumed by the next usable `resize`. */
	private snowflakesPending = false;
	/**
	 * Stable listener reference so `destroy()` can unregister it. It calls `resize()`
	 * with no argument so the resize Event is not forwarded as `updateSnowflakes`.
	 */
	private onResize = (): void => this.resize();

	private density: number = 1 / 90;
	private depth = 100;
	private count = 1000;
	private gravity = 100;
	private speed: number = 1 / 10000;
	private color: number[] = [1, 1, 1];
	private opacity = 1;
	private size = 4;
	private snowflake = '';
	private mode: 'snow' | 'sakura' = 'snow';

	private INITIAL_BUFFERS = () => ({
		position: { size: 3, value: [] },
		color: { size: 4, value: [] },
		size: { size: 1, value: [] },
		rotation: { size: 3, value: [] },
		speed: { size: 3, value: [] },
	});

	private INITIAL_UNIFORMS = () => ({
		time: { type: 'float', value: 0 },
		worldSize: { type: 'vec3', value: [0, 0, 0] },
		gravity: { type: 'float', value: this.gravity },
		wind: { type: 'float', value: 0 },
		spin_factor: { type: 'float', value: this.mode === 'sakura' ? 8 : 1 },
		turbulence: { type: 'float', value: this.mode === 'sakura' ? 2 : 1 },
		projection: {
			type: 'mat4',
			value: [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
		},
	});

	/** Maps a GLSL uniform type to the name of the WebGL method that uploads it. */
	private UNIFORM_SETTERS = {
		int: 'uniform1i',
		float: 'uniform1f',
		vec2: 'uniform2fv',
		vec3: 'uniform3fv',
		vec4: 'uniform4fv',
		mat2: 'uniformMatrix2fv',
		mat3: 'uniformMatrix3fv',
		mat4: 'uniformMatrix4fv',
	};

	private CAMERA = {
		fov: 60,
		near: 5,
		far: 10000,
		aspect: 1,
		z: 100,
	};

	private WIND = {
		current: 0,
		force: 0.01,
		target: 0.01,
		min: 0,
		max: 0.125,
		easing: 0.0005,
	};

	/**
	 * Creates the canvas and all GL resources, then attaches the canvas to the page.
	 * Nothing is drawn until `render()` is called.
	 *
	 * @param options.sakura - Use the sakura preset (larger, sparser particles with more spin/turbulence).
	 * @throws {Error} - Thrown when it fails to get a WebGL2 context for the canvas, or when
	 *   creating, compiling or linking the shaders, or creating a buffer/texture, fails.
	 */
	constructor(options: {
		sakura?: boolean;
	}) {
		if (options.sakura) {
			this.mode = 'sakura';
			this.snowflake = '';
			this.size = 10;
			this.density = 1 / 280;
		}

		const canvas = this.initCanvas();
		const gl = canvas.getContext('webgl2', { antialias: true });
		if (gl == null) throw new Error('Failed to get WebGL context');

		this.canvas = canvas;
		this.gl = gl;
		this.program = this.initProgram();
		this.buffers = this.initBuffers();
		this.uniforms = this.initUniforms();
		this.texture = this.initTexture();
		this.camera = this.initCamera();
		this.wind = this.initWind();

		// `update` is handed to requestAnimationFrame, so it must keep `this`.
		this.update = this.update.bind(this);

		// Attach to the page only after every GL resource was created, so a failure
		// above does not leave an empty full-screen canvas behind.
		window.document.body.append(canvas);
		window.addEventListener('resize', this.onResize);
	}

	/** Creates the (not yet attached) full-viewport, click-through overlay canvas. */
	private initCanvas(): HTMLCanvasElement {
		const canvas = window.document.createElement('canvas');

		Object.assign(canvas.style, {
			position: 'fixed',
			top: 0,
			left: 0,
			width: '100vw',
			height: '100vh',
			background: 'transparent',
			'pointer-events': 'none',
			'z-index': 2147483647,
		});

		return canvas;
	}

	/** Returns a fresh, mutable copy of the default camera settings. */
	private initCamera() {
		return { ...this.CAMERA };
	}

	/** Returns a fresh, mutable copy of the default wind state. */
	private initWind() {
		return { ...this.WIND };
	}

	/**
	 * Creates and compiles a shader.
	 *
	 * @param type - `gl.VERTEX_SHADER` or `gl.FRAGMENT_SHADER`.
	 * @param source - GLSL source code.
	 * @throws {Error} When the shader cannot be created or fails to compile
	 *   (the message includes the driver's info log).
	 */
	private initShader(type: number, source: string): WebGLShader {
		const { gl } = this;
		const shader = gl.createShader(type);
		if (shader == null) throw new Error('Failed to create shader');

		gl.shaderSource(shader, source);
		gl.compileShader(shader);

		// Compilation failures are otherwise silent and only surface later as
		// link errors / GL errors on every draw call.
		if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
			const log = gl.getShaderInfoLog(shader);
			gl.deleteShader(shader);
			throw new Error(`Failed to compile shader: ${log}`);
		}

		return shader;
	}

	/**
	 * Compiles both shaders, links them into a program and makes it current.
	 *
	 * @throws {Error} When a shader fails to compile, or the program cannot be
	 *   created or fails to link (the message includes the driver's info log).
	 */
	private initProgram(): WebGLProgram {
		const { gl } = this;
		const vertex = this.initShader(gl.VERTEX_SHADER, this.VERTEX_SOURCE);
		const fragment = this.initShader(gl.FRAGMENT_SHADER, this.FRAGMENT_SOURCE);
		const program = gl.createProgram();
		if (program == null) throw new Error('Failed to create program');

		gl.attachShader(program, vertex);
		gl.attachShader(program, fragment);
		gl.linkProgram(program);

		// Attached shaders are only flagged for deletion and are freed together
		// with the program, so our own references can be released right away.
		gl.deleteShader(vertex);
		gl.deleteShader(fragment);

		if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
			const log = gl.getProgramInfoLog(program);
			gl.deleteProgram(program);
			throw new Error(`Failed to link program: ${log}`);
		}

		gl.useProgram(program);

		return program;
	}

	/**
	 * Creates one GL buffer per entry of `INITIAL_BUFFERS` and wires it to the
	 * matching `a_<name>` attribute as tightly packed floats.
	 *
	 * @throws {Error} When a buffer cannot be created.
	 */
	private initBuffers(): SnowfallEffect['buffers'] {
		const { gl, program } = this;
		const buffers: SnowfallEffect['buffers'] = {};

		for (const [name, { size, value }] of Object.entries(this.INITIAL_BUFFERS())) {
			const location = gl.getAttribLocation(program, `a_${name}`);
			const ref = gl.createBuffer();
			if (ref == null) throw new Error('Failed to create buffer');

			// getAttribLocation returns -1 for attributes the linked program does not
			// use (e.g. optimised away); enabling/pointing location -1 is an
			// INVALID_VALUE error. The buffer is still created so setBuffer() can
			// upload to it unconditionally.
			if (location !== -1) {
				gl.bindBuffer(gl.ARRAY_BUFFER, ref);
				gl.enableVertexAttribArray(location);
				gl.vertexAttribPointer(
					location,
					size,
					gl.FLOAT,
					false,
					0,
					0,
				);
			}

			buffers[name] = { size, value, location, ref };
		}

		return buffers;
	}

	/** Re-uploads every attribute buffer from its current `value`. */
	private updateBuffers() {
		const { buffers } = this;

		for (const name of Object.keys(buffers)) {
			this.setBuffer(name);
		}
	}

	/**
	 * Stores `value` (or keeps the current one when omitted) as a Float32Array and
	 * uploads it to the named attribute buffer.
	 */
	private setBuffer(name: string, value?: number[] | Float32Array) {
		const { gl, buffers } = this;
		const buffer = buffers[name];

		buffer.value = new Float32Array(value ?? buffer.value);

		gl.bindBuffer(gl.ARRAY_BUFFER, buffer.ref);
		gl.bufferData(gl.ARRAY_BUFFER, buffer.value, gl.STATIC_DRAW);
	}

	/** Looks up the `u_<name>` location for every entry of `INITIAL_UNIFORMS`. */
	private initUniforms(): SnowfallEffect['uniforms'] {
		const { gl, program } = this;
		const uniforms: SnowfallEffect['uniforms'] = {};

		for (const [name, { type, value }] of Object.entries(this.INITIAL_UNIFORMS())) {
			uniforms[name] = {
				type,
				value,
				location: gl.getUniformLocation(program, `u_${name}`),
			};
		}

		return uniforms;
	}

	/** Re-uploads every uniform from its current `value`. */
	private updateUniforms() {
		const { uniforms } = this;

		for (const name of Object.keys(uniforms)) {
			this.setUniform(name);
		}
	}

	/**
	 * Stores `value` (or keeps the current one when omitted) and uploads it with the
	 * setter for the uniform's GLSL type. Matrices are uploaded untransposed.
	 */
	private setUniform(name: string, value?: number | number[] | Float32Array) {
		const { gl, uniforms } = this;
		const uniform = uniforms[name];
		const setter = this.UNIFORM_SETTERS[uniform.type];
		const isMatrix = /^mat[2-4]$/i.test(uniform.type);

		uniform.value = value ?? uniform.value;

		if (isMatrix) {
			gl[setter](uniform.location, false, uniform.value);
		} else {
			gl[setter](uniform.location, uniform.value);
		}
	}

	/**
	 * Creates the particle texture. It starts as a single transparent pixel and is
	 * replaced by the `snowflake` image (with linear filtering and edge clamping)
	 * once that image has loaded.
	 *
	 * @throws {Error} When the texture cannot be created.
	 */
	private initTexture() {
		const { gl } = this;
		const texture = gl.createTexture();
		if (texture == null) throw new Error('Failed to create texture');
		const image = new Image();

		gl.bindTexture(gl.TEXTURE_2D, texture);
		gl.texImage2D(
			gl.TEXTURE_2D,
			0,
			gl.RGBA,
			1,
			1,
			0,
			gl.RGBA,
			gl.UNSIGNED_BYTE,
			new Uint8Array([0, 0, 0, 0]),
		);

		image.onload = () => {
			gl.bindTexture(gl.TEXTURE_2D, texture);
			gl.texImage2D(
				gl.TEXTURE_2D,
				0,
				gl.RGBA,
				gl.RGBA,
				gl.UNSIGNED_BYTE,
				image,
			);
			gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
			gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
			gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
			gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
		};

		image.src = this.snowflake;

		return texture;
	}

	/**
	 * Generates `count * (vw / vh)` particles with random position, speed, rotation
	 * and alpha, uploads them to the attribute buffers, and publishes the spawn box
	 * half-extents as `u_worldSize`.
	 *
	 * Positions are drawn from x in [-width, width), y in [-height, height),
	 * z in [0, 2 * depth), where height = 1 / density and width keeps the viewport
	 * aspect ratio. Callers must pass a non-zero `vh`.
	 */
	private initSnowflakes(vw: number, vh: number, dpi: number) {
		const position: number[] = [];
		const color: number[] = [];
		const size: number[] = [];
		const rotation: number[] = [];
		const speed: number[] = [];

		const height = 1 / this.density;
		const width = (vw / vh) * height;
		const depth = this.depth;
		const count = this.count;
		const length = (vw / vh) * count;

		for (let i = 0; i < length; ++i) {
			position.push(
				-width + Math.random() * width * 2,
				-height + Math.random() * height * 2,
				Math.random() * depth * 2,
			);

			speed.push(1 + Math.random(), 1 + Math.random(), Math.random() * 10);

			rotation.push(
				Math.random() * 2 * Math.PI,
				Math.random() * 20,
				Math.random() * 10,
			);

			color.push(...this.color, 0.1 + Math.random() * this.opacity);
			//size.push((this.size * Math.random() * this.size * vh * dpi) / 1000);
			// Point size scales with the physical viewport height.
			size.push((this.size * vh * dpi) / 1000);
		}

		this.setUniform('worldSize', [width, height, depth]);

		this.setBuffer('position', position);
		this.setBuffer('color', color);
		this.setBuffer('rotation', rotation);
		this.setBuffer('size', size);
		this.setBuffer('speed', speed);
	}

	/**
	 * Records `aspect` on the camera and returns a perspective projection built from
	 * the camera's fov/near/far, laid out for `uniformMatrix4fv` with transpose=false
	 * (column-major). `camera.z` is folded into the last column.
	 */
	private setProjection(aspect: number) {
		const { camera } = this;

		camera.aspect = aspect;

		const fovRad = (camera.fov * Math.PI) / 180;
		const f = Math.tan(Math.PI * 0.5 - 0.5 * fovRad);
		const rangeInv = 1.0 / (camera.near - camera.far);

		const m0 = f / camera.aspect;
		const m5 = f;
		const m10 = (camera.near + camera.far) * rangeInv;
		const m11 = -1;
		const m14 = camera.near * camera.far * rangeInv * 2 + camera.z;
		const m15 = camera.z;

		return [m0, 0, 0, 0, 0, m5, 0, 0, 0, 0, m10, m11, 0, 0, m14, m15];
	}

	/**
	 * Configures blending, (re)generates the particles for the current canvas size
	 * and starts the animation loop. Calling it again restarts the loop.
	 *
	 * @returns This instance, for chaining.
	 */
	public render() {
		const { gl } = this;

		gl.enable(gl.BLEND);
		gl.enable(gl.CULL_FACE);
		// Additive blending: overlapping particles brighten instead of occluding.
		gl.blendFunc(gl.SRC_ALPHA, gl.ONE);
		gl.disable(gl.DEPTH_TEST);

		this.updateBuffers();
		this.updateUniforms();
		this.resize(true);

		this.time = {
			start: window.performance.now(),
			previous: window.performance.now(),
		};

		if (this.raf) window.cancelAnimationFrame(this.raf);
		this.raf = window.requestAnimationFrame(this.update);

		return this;
	}

	/**
	 * Stops the effect and removes it from the page: cancels the pending animation
	 * frame, detaches the resize listener, deletes the GL buffers, texture and
	 * program created by this instance, and removes the canvas from the DOM.
	 * The instance must not be used after this call.
	 */
	public destroy() {
		const { gl } = this;

		if (this.raf) window.cancelAnimationFrame(this.raf);
		this.raf = 0;
		window.removeEventListener('resize', this.onResize);

		for (const buffer of Object.values(this.buffers)) {
			gl.deleteBuffer(buffer.ref);
		}
		gl.deleteTexture(this.texture);
		gl.deleteProgram(this.program);

		this.canvas.remove();
	}

	/**
	 * Matches the drawing buffer and viewport to the canvas' CSS size times
	 * `devicePixelRatio` and updates the projection.
	 *
	 * @param updateSnowflakes - Also regenerate the particles for the new size.
	 *   Compared with `=== true` so a stray non-boolean argument is ignored.
	 */
	private resize(updateSnowflakes = false) {
		const { canvas, gl } = this;
		const vw = canvas.offsetWidth;
		const vh = canvas.offsetHeight;

		// With a zero-sized canvas `vw / vh` is Infinity or NaN, and initSnowflakes
		// would then loop `Infinity` times. Remember any snowflake request and
		// retry on the next resize that reports a real size.
		if (vw === 0 || vh === 0) {
			if (updateSnowflakes === true) this.snowflakesPending = true;
			return;
		}

		const aspect = vw / vh;
		const dpi = window.devicePixelRatio;

		canvas.width = vw * dpi;
		canvas.height = vh * dpi;

		gl.viewport(0, 0, vw * dpi, vh * dpi);
		gl.clearColor(0, 0, 0, 0);

		if (updateSnowflakes === true || this.snowflakesPending) {
			this.snowflakesPending = false;
			this.initSnowflakes(vw, vh, dpi);
		}

		this.setUniform('projection', this.setProjection(aspect));
	}

	/**
	 * Animation-frame callback: draws all particles, advances the wind simulation
	 * and the time uniform, then schedules itself again.
	 */
	private update(timestamp: number) {
		const { gl, buffers, wind } = this;
		const elapsed = (timestamp - this.time.start) * this.speed;
		const delta = timestamp - this.time.previous;

		gl.clear(gl.COLOR_BUFFER_BIT);
		gl.drawArrays(
			gl.POINTS,
			0,
			buffers.position.value.length / buffers.position.size,
		);

		// On roughly 0.5% of frames pick a new wind target in ±[min, max].
		if (Math.random() > 0.995) {
			wind.target =
				(wind.min + Math.random() * (wind.max - wind.min)) *
				(Math.random() > 0.5 ? -1 : 1);
		}

		// Ease the force towards the target, then integrate it into the offset.
		wind.force += (wind.target - wind.force) * wind.easing;
		wind.current += wind.force * (delta * 0.2);

		this.setUniform('wind', wind.current);
		this.setUniform('time', elapsed);

		this.time.previous = timestamp;

		this.raf = window.requestAnimationFrame(this.update);
	}
}
